use reqwest::header::{HeaderMap, HeaderValue, USER_AGENT};
use reqwest::Url;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, trace};

use crate::agents::extension::ExtensionError;

/// Client for querying an OSV-style vulnerability API for malicious-package
/// (`MAL-*`) advisories.
///
/// Construct one with [`OsvChecker::new`] or [`OsvChecker::with_endpoint`].
#[derive(Clone)]
pub struct OsvChecker {
    /// HTTP client used to send query requests.
    client: reqwest::Client,
    /// URL that query requests are POSTed to.
    endpoint: Url,
}

impl OsvChecker {
    /// Builds a checker whose endpoint comes from the `OSV_ENDPOINT`
    /// environment variable.
    ///
    /// When the variable is unset or its value does not parse as a URL, the
    /// built-in default endpoint is used instead.
    ///
    /// # Errors
    /// Returns the boxed error produced while building the HTTP client.
    pub fn new() -> Result<Self, Box<ExtensionError>> {
        let client = match http_client() {
            Ok(client) => client,
            Err(err) => return Err(Box::new(err)),
        };

        let from_env = match std::env::var("OSV_ENDPOINT") {
            Ok(raw) => Url::parse(&raw).ok(),
            Err(_) => None,
        };
        let endpoint = match from_env {
            Some(url) => url,
            None => Url::parse(DEFAULT_OSV_ENDPOINT).expect("valid default OSV url"),
        };

        Ok(Self { client, endpoint })
    }

    /// Builds a checker that talks to an explicitly supplied endpoint
    /// (useful for pointing tests at a mock server).
    ///
    /// # Errors
    /// Returns the boxed error produced while building the HTTP client.
    pub fn with_endpoint(endpoint: Url) -> Result<Self, Box<ExtensionError>> {
        match http_client() {
            Ok(client) => Ok(Self { client, endpoint }),
            Err(err) => Err(Box::new(err)),
        }
    }

    /// Queries OSV for the package and returns an error if any `MAL-*`
    /// advisory is reported.
    ///
    /// - `name`: package name as known to the ecosystem.
    /// - `ecosystem`: e.g. `"npm"` or `"PyPI"`.
    /// - `version`: when `None`, the lookup is by name only.
    pub async fn deny_if_malicious(
        &self,
        name: &str,
        ecosystem: &str,
        version: Option<&str>,
    ) -> Result<(), ExtensionError> {
        let Self { client, endpoint } = self;
        deny_if_malicious_impl(client, endpoint, name, ecosystem, version).await
    }
}

/// Runs the OSV check for a launcher command line.
///
/// The ecosystem is inferred from the command token:
/// - a command ending in `uvx` maps to `PyPI`
/// - a command ending in `npx` maps to `npm`
/// - anything else is skipped (fail open)
///
/// The package is taken from the first non-flag argument; when no package
/// token can be extracted the check is skipped as well.
pub async fn deny_if_malicious_cmd_args(cmd: &str, args: &[String]) -> Result<(), ExtensionError> {
    let ecosystem = match (cmd.ends_with("uvx"), cmd.ends_with("npx")) {
        (true, _) => "PyPI",
        (false, true) => "npm",
        (false, false) => {
            debug!(%cmd, ?args, "Unknown ecosystem for command; skipping OSV check (fail open).");
            return Ok(());
        }
    };

    match parse_first_package_arg(ecosystem, args) {
        Some((name, version)) => {
            // `new` returns a boxed error; unbox it to match this function's error type.
            let checker = OsvChecker::new().map_err(|boxed| *boxed)?;
            checker
                .deny_if_malicious(&name, ecosystem, version.as_deref())
                .await
        }
        None => {
            debug!(%cmd, ?args, "No package token found; skipping OSV check.");
            Ok(())
        }
    }
}

/// Checks a single package against OSV without inferring anything from a
/// command line; a fresh [`OsvChecker`] is created for the call.
pub async fn deny_if_malicious(
    name: &str,
    ecosystem: &str,
    version: Option<&str>,
) -> Result<(), ExtensionError> {
    let checker = match OsvChecker::new() {
        Ok(checker) => checker,
        // The constructor boxes its error; callers of this function expect it unboxed.
        Err(boxed) => return Err(*boxed),
    };
    checker.deny_if_malicious(name, ecosystem, version).await
}

/// Extracts `(name, version)` from the first argument that does not start
/// with `-`.
///
/// Returns `None` when there is no such argument, when that argument is
/// blank after trimming, when the ecosystem is neither `"npm"` nor `"PyPI"`,
/// or when the ecosystem-specific parser rejects the token.
fn parse_first_package_arg(ecosystem: &str, args: &[String]) -> Option<(String, Option<String>)> {
    let first_positional = args
        .iter()
        .map(String::as_str)
        .find(|arg| !arg.starts_with('-'))?;

    let token = first_positional.trim();
    if token.is_empty() {
        return None;
    }

    if ecosystem == "npm" {
        parse_npm_token(token)
    } else if ecosystem == "PyPI" {
        parse_pypi_token(token)
    } else {
        None
    }
}

/// Splits an npm package token into `(name, version)`.
///
/// Recognised shapes:
/// - `react@18.3.1`       -> name `react`, version `18.3.1`
/// - `@scope/pkg@1.2.3`   -> name `@scope/pkg`, version `1.2.3`
///   (scoped names are split at the LAST `@`, unscoped ones at the first)
/// - `eslint`             -> name only
///
/// An empty version or the literal `latest` yields no version.
fn parse_npm_token(token: &str) -> Option<(String, Option<String>)> {
    // Position of the '@' separating name from version, if there is one.
    // For scoped packages the leading '@' (index 0) belongs to the name.
    let separator = if token.starts_with('@') {
        token.rfind('@').filter(|&pos| pos > 0)
    } else {
        token.find('@')
    };

    let (name, raw_version) = match separator {
        Some(pos) => (&token[..pos], token[pos..].trim_start_matches('@')),
        None => return Some((token.to_string(), None)),
    };

    let version = match raw_version {
        "" | "latest" => None,
        pinned => Some(pinned.to_string()),
    };
    Some((name.to_string(), version))
}

/// Splits a PyPI requirement token into `(name, version)`.
///
/// The returned name is always the bare distribution name: extras and any
/// version specifier are stripped so the name can be looked up as-is.
///
/// Recognised shapes:
/// - `package`                -> name only
/// - `package==1.2.3`         -> exact pin
/// - `package[extra]==1.2.3`  -> exact pin, extras dropped
/// - `package@1.2.3`          -> exact pin (`uvx` shorthand); only accepted
///   when the part after `@` starts with a digit, so `name @ <url>` direct
///   references are checked by name only
/// - `package>=1.0`, `<=`, `~=`, `!=`, `<`, `>` -> name only (not a pin)
///
/// `latest`, an empty version, and wildcard pins such as `==1.*` yield no
/// version. Returns `None` when the token has no name part (for example an
/// empty token or `==1.0`).
fn parse_pypi_token(token: &str) -> Option<(String, Option<String>)> {
    let token = token.trim();

    // The name ends at the first character that can begin extras, a version
    // specifier, an `@` reference, or an environment marker.
    let is_delimiter = |c: char| {
        c.is_whitespace() || matches!(c, '[' | '=' | '<' | '>' | '~' | '!' | '@' | ';')
    };
    let name_end = token.find(is_delimiter).unwrap_or(token.len());
    let (name, rest) = token.split_at(name_end);
    if name.is_empty() {
        return None;
    }

    // Skip an optional `[extras]` group; an unterminated group swallows the rest.
    let rest = rest.trim_start();
    let rest = match rest.strip_prefix('[') {
        Some(after_open) => match after_open.find(']') {
            Some(close) => &after_open[close + 1..],
            None => "",
        },
        None => rest,
    };
    let spec = rest.trim_start();

    let version = if let Some(pinned) = spec.strip_prefix("==") {
        // `trim_start_matches` also covers the `===` arbitrary-equality form.
        exact_pypi_version(pinned.trim_start_matches('='))
    } else if let Some(tagged) = spec.strip_prefix('@') {
        exact_pypi_version(tagged).filter(|v| v.starts_with(|c: char| c.is_ascii_digit()))
    } else {
        None
    };

    Some((name.to_string(), version.map(str::to_string)))
}

/// Returns the leading version of `raw` if it is an exact pin: the text up to
/// the first `,`, `;` or whitespace, rejected when empty, `latest`
/// (case-insensitive), or containing a `*` wildcard.
fn exact_pypi_version(raw: &str) -> Option<&str> {
    let raw = raw.trim_start();
    let end = raw
        .find(|c: char| c.is_whitespace() || c == ',' || c == ';')
        .unwrap_or(raw.len());
    let candidate = &raw[..end];
    let usable = !candidate.is_empty()
        && !candidate.eq_ignore_ascii_case("latest")
        && !candidate.contains('*');
    if usable {
        Some(candidate)
    } else {
        None
    }
}

/// Query endpoint used by `OsvChecker::new` when the `OSV_ENDPOINT`
/// environment variable is unset or does not parse as a URL.
const DEFAULT_OSV_ENDPOINT: &str = "https://api.osv.dev/v1/query";

/// JSON request body sent to the OSV query endpoint.
#[derive(Serialize)]
struct QueryReq<'a> {
    /// Package version to query; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    version: Option<&'a str>,
    /// The package being queried.
    package: Package<'a>,
    /// Pagination token taken from a previous response; left out of the JSON
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    page_token: Option<String>,
}

/// Package identification nested inside [`QueryReq`].
#[derive(Serialize)]
struct Package<'a> {
    /// Package name.
    name: &'a str,
    /// Ecosystem the name belongs to, e.g. "npm" or "PyPI".
    ecosystem: &'a str,
    /// Optional package URL; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    purl: Option<&'a str>,
}

/// Subset of the OSV query response that this module decodes.
#[derive(Deserialize)]
struct QueryResp {
    /// Advisories on this page; defaults to an empty list when the field is absent.
    #[serde(default)]
    vulns: Vec<Vuln>,
    /// Token for requesting the next page; `None` when the field is absent or null.
    #[serde(default)]
    next_page_token: Option<String>,
}

/// A single advisory entry from an OSV response; only the fields declared
/// here are decoded.
#[derive(Deserialize)]
struct Vuln {
    /// Advisory identifier; ids starting with `MAL-` are treated as
    /// malicious-package reports by the query logic in this file.
    id: String,
    /// Short description; defaults to an empty string when the field is absent.
    #[serde(default)]
    summary: String,
}

async fn deny_if_malicious_impl(
    client: &reqwest::Client,
    endpoint: &Url,
    name: &str,
    ecosystem: &str,
    version: Option<&str>,
) -> Result<(), ExtensionError> {
    debug!(name, ecosystem, ?version, "OSV query starting");
    let mut page_token: Option<String> = None;
    let mut mal: Vec<Vuln> = Vec::new();

    loop {
        let body = QueryReq {
            version,
            package: Package {
                name,
                ecosystem,
                purl: None,
            },
            page_token: page_token.clone(),
        };
        trace!(?body.page_token, "OSV page");

        let resp = match client.post(endpoint.clone()).json(&body).send().await {
            Ok(r) => r,
            Err(e) => {
                error!(%e, name, ecosystem, ?version, "OSV request failed; failing open.");
                return Ok(());
            }
        };

        let resp = match resp.error_for_status() {
            Ok(r) => r,
            Err(e) => {
                error!(%e, name, ecosystem, ?version, "OSV HTTP error; failing open.");
                return Ok(());
            }
        };

        let payload: QueryResp = match resp.json().await {
            Ok(p) => p,
            Err(e) => {
                error!(%e, name, ecosystem, ?version, "OSV JSON parse error; failing open.");
                return Ok(());
            }
        };

        mal.extend(
            payload
                .vulns
                .into_iter()
                .filter(|v| v.id.starts_with("MAL-")),
        );

        match payload.next_page_token {
            Some(tok) if !tok.is_empty() => page_token = Some(tok),
            _ => break,
        }
    }

    if !mal.is_empty() {
        let ver = version.unwrap_or("<any>");
        let details = mal
            .into_iter()
            .map(|v| {
                if v.summary.is_empty() {
                    v.id
                } else {
                    format!("{} — {}", v.id, v.summary)
                }
            })
            .collect::<Vec<_>>()
            .join("; ");
        error!(name, ecosystem, version=%ver, %details, "Blocked malicious package via OSV MAL-*.");
        return Err(ExtensionError::ConfigError(format!(
            "Blocked malicious package: {name}@{ver} ({ecosystem}). OSV MAL advisories: {details}"
        )));
    }

    debug!(name, ecosystem, ?version, "OSV: no MAL advisories.");
    Ok(())
}

/// Builds the HTTP client used for OSV queries: a fixed `User-Agent` default
/// header and a 10-second request timeout.
///
/// # Errors
/// Returns `ExtensionError::SetupError` when the client cannot be built.
#[allow(clippy::result_large_err)]
fn http_client() -> Result<reqwest::Client, ExtensionError> {
    let user_agent = HeaderValue::from_static("goose-osv-check/1.1 (+https://osv.dev)");
    let mut default_headers = HeaderMap::new();
    default_headers.insert(USER_AGENT, user_agent);

    let request_timeout = std::time::Duration::from_secs(10);
    let built = reqwest::Client::builder()
        .default_headers(default_headers)
        .timeout(request_timeout)
        .build();

    match built {
        Ok(client) => Ok(client),
        Err(e) => Err(ExtensionError::SetupError(format!(
            "failed to build HTTP client: {e}"
        ))),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use serial_test;
    use tokio;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Full query URL served by the given mock server.
    fn query_url(server: &MockServer) -> String {
        format!("{}/v1/query", server.uri())
    }

    /// Checker wired directly to the mock server (no environment involved).
    fn checker_for(server: &MockServer) -> OsvChecker {
        let url = Url::parse(&query_url(server)).unwrap();
        OsvChecker::with_endpoint(url).unwrap()
    }

    /// 200 response carrying the given JSON body.
    fn ok_json(body: serde_json::Value) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(body)
    }

    /// Mounts a `POST /v1/query` mock answering with `response`; when
    /// `max_calls` is set the mock stops matching after that many requests.
    async fn mount_query(server: &MockServer, response: ResponseTemplate, max_calls: Option<u64>) {
        let mock = Mock::given(method("POST"))
            .and(path("/v1/query"))
            .respond_with(response);
        let mock = match max_calls {
            Some(limit) => mock.up_to_n_times(limit),
            None => mock,
        };
        mock.mount(server).await;
    }

    /// Asserts the result is an error and returns its `Debug` rendering.
    fn error_text(res: Result<(), ExtensionError>) -> String {
        assert!(res.is_err());
        format!("{:?}", res.unwrap_err())
    }

    /// Sets an environment variable for the lifetime of the guard and puts
    /// the previous state back (old value or absence) on drop.
    struct TempEnvVar {
        key: String,
        original: Option<String>,
    }

    impl TempEnvVar {
        fn set(key: &str, value: &str) -> Self {
            let original = std::env::var(key).ok();
            std::env::set_var(key, value);
            Self {
                key: key.to_string(),
                original,
            }
        }
    }

    impl Drop for TempEnvVar {
        fn drop(&mut self) {
            if let Some(previous) = &self.original {
                std::env::set_var(&self.key, previous);
            } else {
                std::env::remove_var(&self.key);
            }
        }
    }

    #[tokio::test]
    async fn allows_clean_package() {
        let server = MockServer::start().await;
        let clean = json!({ "vulns": [], "next_page_token": null });
        mount_query(&server, ok_json(clean), None).await;

        let res = checker_for(&server)
            .deny_if_malicious("some_clean_package", "PyPI", None)
            .await;
        assert!(res.is_ok());
    }

    #[tokio::test]
    async fn blocks_malicious_package() {
        let server = MockServer::start().await;
        let hit = json!({
            "vulns": [ { "id": "MAL-1234", "summary": "Malicious package" } ],
            "next_page_token": null
        });
        mount_query(&server, ok_json(hit), None).await;

        let res = checker_for(&server)
            .deny_if_malicious("bad_package", "PyPI", Some("1.0.0"))
            .await;
        let msg = error_text(res);
        assert!(msg.contains("Blocked malicious package"));
        assert!(msg.contains("MAL-1234"));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn cmd_args_pypi_clean() {
        let server = MockServer::start().await;
        let clean = json!({ "vulns": [], "next_page_token": null });
        mount_query(&server, ok_json(clean), None).await;

        // The free function builds its checker via `OsvChecker::new()`, which
        // reads the endpoint from the environment.
        let _env = TempEnvVar::set("OSV_ENDPOINT", &query_url(&server));
        let args: Vec<String> = vec!["some_clean_package==1.2.3".to_string()];
        let res = deny_if_malicious_cmd_args("uvx", &args).await;
        assert!(res.is_ok());
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn cmd_args_npm_scoped_malicious() {
        let server = MockServer::start().await;
        let hit = json!({
            "vulns": [ { "id": "MAL-9999", "summary": "Malicious npm package" } ],
            "next_page_token": null
        });
        mount_query(&server, ok_json(hit), None).await;

        let _env = TempEnvVar::set("OSV_ENDPOINT", &query_url(&server));
        let args: Vec<String> = vec!["@scope/pkg@2.0.0".to_string()];
        let res = deny_if_malicious_cmd_args("npx", &args).await;
        let msg = error_text(res);
        assert!(msg.contains("Blocked malicious package"));
        assert!(msg.contains("MAL-9999"));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn cmd_args_skip_flags_then_parse() {
        let server = MockServer::start().await;
        let clean = json!({ "vulns": [], "next_page_token": null });
        mount_query(&server, ok_json(clean), None).await;

        let _env = TempEnvVar::set("OSV_ENDPOINT", &query_url(&server));
        let args: Vec<String> = ["--dry-run", "-y", "some_clean_package@1.2.3"]
            .iter()
            .map(|arg| arg.to_string())
            .collect();
        let res = deny_if_malicious_cmd_args("npx", &args).await;
        assert!(res.is_ok());
    }

    #[tokio::test]
    async fn pagination_works() {
        let server = MockServer::start().await;

        // First page: nothing found, but a continuation token is returned.
        // Limited to one match so the follow-up request reaches the next mock.
        let first_page = json!({ "vulns": [], "next_page_token": "page-2" });
        mount_query(&server, ok_json(first_page), Some(1)).await;

        // Second page: carries the MAL advisory.
        let second_page = json!({
            "vulns": [ { "id": "MAL-4242", "summary": "Second page hit" } ],
            "next_page_token": null
        });
        mount_query(&server, ok_json(second_page), None).await;

        let res = checker_for(&server)
            .deny_if_malicious("pkg", "npm", None)
            .await;
        let msg = error_text(res);
        assert!(msg.contains("MAL-4242"));
    }

    #[tokio::test]
    async fn fail_open_on_http_error() {
        let server = MockServer::start().await;
        mount_query(&server, ResponseTemplate::new(500), None).await;

        let res = checker_for(&server)
            .deny_if_malicious("pkg", "npm", None)
            .await;
        assert!(res.is_ok(), "should fail-open on HTTP errors");
    }

    #[tokio::test]
    async fn unknown_command_is_skipped() {
        // No mock server here: an unrecognised command must not reach OSV at all.
        let args: Vec<String> = vec!["whatever@1.0.0".to_string()];
        let res = deny_if_malicious_cmd_args("some-other-bin", &args).await;
        assert!(res.is_ok());
    }

    #[test]
    fn parse_npm_scoped_with_version() {
        let parsed = super::parse_npm_token("@scope/pkg@1.2.3");
        assert_eq!(parsed, Some(("@scope/pkg".into(), Some("1.2.3".into()))));
    }

    #[test]
    fn parse_npm_unscoped_latest_is_none() {
        let parsed = super::parse_npm_token("react@latest");
        assert_eq!(parsed, Some(("react".into(), None)));
    }

    #[test]
    fn parse_pypi_exact_pin_and_latest() {
        let pinned = super::parse_pypi_token("requests==2.32.3");
        assert_eq!(pinned, Some(("requests".into(), Some("2.32.3".into()))));

        let latest = super::parse_pypi_token("requests==latest");
        assert_eq!(latest, Some(("requests".into(), None)));
    }
}
